Enabled Weighted Sampling #635

Draft

mkolodner-sc wants to merge 14 commits into main from mkolodner-sc/enable_weighted_sampling

Conversation

@mkolodner-sc
Collaborator

@mkolodner-sc mkolodner-sc commented May 12, 2026

Summary

Adds native weighted edge sampling to GiGL's distributed training pipeline via GLT's CPUWeightedSampler. When enabled, neighbors are sampled proportionally to edge weights rather than uniformly.
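
For intuition, "proportionally to edge weights" means each out-edge is drawn with probability weight / sum(weights), so a weight of 0 means the edge is never sampled. A toy torch illustration (not GiGL or GLT code):

    import torch

    # With edge weights [3.0, 1.0, 0.0], a weighted sampler draws neighbor 0
    # with probability 0.75, neighbor 1 with probability 0.25, and never
    # draws neighbor 2.
    weights = torch.tensor([3.0, 1.0, 0.0])
    draws = torch.multinomial(weights, num_samples=10_000, replacement=True)
    counts = torch.bincount(draws, minlength=3)
    print(counts / counts.sum())  # roughly tensor([0.75, 0.25, 0.00])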

New API

  • DistPartitioner.register_edge_weights(edge_weights) — registers a 1D per-edge weight tensor (homogeneous or dict[EdgeType, Tensor] for heterogeneous) before calling partition_edge_index_and_edge_features(). Weights are partitioned alongside edge features in the same pass (co-partitioned, mirroring the node features + labels pattern).
  • load_torch_tensors_from_tf_record(weight_edge_feat_name=...) — accepts the name of an existing edge feature column to extract as sampling weights during TFRecord loading. The column is sliced out of the feature tensor and stored in LoadedGraphTensors.edge_weights; it is never duplicated in memory.
  • build_dataset(weight_edge_feat_name=...) — threads weight_edge_feat_name through to TFRecord loading and then calls register_edge_weights() with the extracted weights.
  • DistNeighborLoader(with_weight=True) / DistABLPLoader(with_weight=True) — enables weighted sampling. Defaults to False; must be set explicitly. A usage sketch follows this list.
  • BaseDistLoader.validate_with_weight() — shared validation: raises ValueError if with_weight=True but no weights are registered in the dataset; raises NotImplementedError if used with PPRSamplerOptions (weight-proportional PPR residual propagation is deferred to a future PR).
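
Putting the new API together, a minimal usage sketch. The module path, the remaining build_dataset arguments, and the batch handling are assumptions here; only the weight-related pieces come from this PR:

    # Hypothetical imports/arguments; only weight_edge_feat_name,
    # has_edge_weights, and with_weight are the new surface added by this PR.
    dataset = build_dataset(
        serialized_graph_metadata=serialized_graph_metadata,
        distributed_context=distributed_context,
        weight_edge_feat_name="edge_weight",  # extract this edge feature column as weights
    )
    assert dataset.has_edge_weights

    loader = DistNeighborLoader(
        dataset=dataset,
        num_neighbors=[10, 10],
        with_weight=True,  # sample neighbors proportionally to the registered weights
    )
    for batch in loader:
        ...  # neighbors in each sampled subgraph were drawn by weight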

Implementation notes

  • LoadedGraphTensors.edge_weights — new field carrying extracted weights from TFRecord loading through to register_edge_weights().
  • GraphPartitionData.weights (field already existed) carries the partitioned weight tensor to DistDataset._initialize_graph(), which forwards it to GLT's init_graph(edge_weights=...).
  • DistDataset.has_edge_weights property reflects whether weights were registered at construction time.
  • SamplingConfig.with_weight is now threaded through from the loader rather than hardcoded to False.
  • Graph Store mode: DistServer.get_edge_weights_registered() and RemoteDistDataset.fetch_edge_weights_registered() propagate has_edge_weights across the RPC boundary so compute nodes can validate with_weight against the remote dataset (a validation sketch follows this list).
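
Combining the pieces above, a hedged sketch of the shared validation; attribute names like _with_weight, _sampler_options, and _dataset are assumptions, but the raised errors follow the description in the API list:

    def validate_with_weight(self) -> None:
        """Sketch of BaseDistLoader.validate_with_weight, per this PR's description."""
        if not self._with_weight:
            return
        # Weight-proportional PPR residual propagation is deferred to a future PR.
        if isinstance(self._sampler_options, PPRSamplerOptions):
            raise NotImplementedError(
                "with_weight=True is not supported with PPRSamplerOptions yet."
            )
        # In Graph Store mode, has_edge_weights is fetched across the RPC
        # boundary via RemoteDistDataset.fetch_edge_weights_registered().
        if not self._dataset.has_edge_weights:
            raise ValueError(
                "with_weight=True, but no edge weights are registered in the dataset."
            )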

Tests

  • tests/unit/distributed/distributed_weighted_sampling_test.py (8 new tests):
    • Correctness (homogeneous + heterogeneous): weight=0 edges to "bad" nodes are never traversed in sampled subgraphs — verified by encoding node class in features and asserting no bad node appears after weighted sampling.
    • Partitioner edge cases: features only, weights only, neither, both (with consistency check that GraphPartitionData.edge_ids == FeaturePartitionData.ids), and heterogeneous partial weights (one edge type weighted, another not).
  • tests/unit/common/data/dataloaders_test.py (1 new test): test_load_edge_weights_from_tf_record — verifies that load_torch_tensors_from_tf_record correctly extracts a named column into edge_weights, removes it from edge_features, and returns the right shapes and values (illustrated below).
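
The weight-column extraction that the dataloaders test verifies reduces to a simple slice; a self-contained shape-level illustration (tensor contents are made up):

    import torch

    num_edges = 5
    feature_keys = ["f0", "f1", "edge_weight"]  # "edge_weight" is the weight column
    edge_features = torch.arange(num_edges * 3, dtype=torch.float32).reshape(num_edges, 3)

    col_idx = feature_keys.index("edge_weight")
    edge_weights = edge_features[:, col_idx]  # shape [num_edges], stored separately
    remaining = edge_features[:, [i for i in range(3) if i != col_idx]]
    assert edge_weights.shape == (num_edges,)
    assert remaining.shape == (num_edges, 2)  # weight column removed, not duplicated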

@mkolodner-sc mkolodner-sc changed the title [WIP] Enabled Weighted Sampling Enabled Weighted Sampling May 13, 2026
Collaborator

@kmontemayor2-sc kmontemayor2-sc left a comment


Thanks Matt! The robots and I did a first pass; it's possible they're imagining some of the issues here, but I figured I'd flag them :)

Comment thread gigl/common/data/load_torch_tensors.py Outdated
Comment thread gigl/distributed/dist_ablp_neighborloader.py Outdated
Comment thread gigl/distributed/dist_dataset.py Outdated
for edge_type, graph_partition_data in partitioned_edge_index.items()
if graph_partition_data.weights is not None
}
edge_weights = weights_by_type if weights_by_type else None
Collaborator

Does this var do anything? Can we remove it?

Comment thread gigl/distributed/dist_partitioner.py Outdated
"""Registers per-edge sampling weights to the partitioner.

Weights must be a 1-D float tensor of shape ``[num_edges]``, one scalar per edge.
Must be called after ``register_edge_index()`` and before ``partition()``.
Collaborator

Why is this the case? Where does this fail otherwise?

Collaborator

Is it just so we can check that they are the same shape later? Because we do self._edge_index is not None, which implies we'd skip this check.

Comment thread gigl/distributed/dist_partitioner.py Outdated
Collaborator

More robot findings:

Lines 188–330 — _partition_edge_index_and_edge_features ignores _edge_weights

The override never reads self._edge_weights, never partitions it, and never sets GraphPartitionData.weights. Because build_dataset_from_task_config_uri defaults to range partitioning (dataset_factory.py:643-646), any user who calls register_edge_weights while using the default config-driven path
silently loses all weights. Downstream DistDataset._has_edge_weights becomes False, and validate_with_weight fails with a confusing "no edge weights registered" error even though weights were registered.

Fix: Mirror DistPartitioner._partition_edge_index_and_edge_features. Roughly:

  • After line 223:

        weight_tensor = None
        if self._edge_weights is not None and edge_type in self._edge_weights:
            weight_tensor = self._edge_weights[edge_type]

  • Augment input_data on lines 229/234 to include weight_tensor (e.g. (..., edge_feat, weight_tensor)).
  • After _partition_by_chunk on line 246, slice the partitioned weights back out of res_list.
  • Set GraphPartitionData(..., weights=partitioned_weights) in both branches at lines 308/320.
  • del self._edge_weights[edge_type] and clear the dict when empty, mirroring the edge_feat cleanup at lines 253-262.
  • Add a regression test for DistRangePartitioner + register_edge_weights (a sketch follows).
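
A hedged sketch of that regression test; the constructor arguments and result field names are assumptions, not GiGL's actual test API:

    # Hypothetical fixture: partition a tiny graph with DistRangePartitioner
    # and check that registered weights survive partitioning.
    partitioner = DistRangePartitioner(num_parts=1)  # constructor args assumed
    partitioner.register_edge_index(edge_index)
    partitioner.register_edge_weights(edge_weights)
    output = partitioner.partition()
    graph = output.partitioned_edge_index  # field name assumed
    assert graph.weights is not None
    assert graph.weights.shape[0] == graph.edge_ids.shape[0]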

],
rank: int,
tf_dataset_options: TFDatasetOptions = TFDatasetOptions(),
weight_edge_feat_name: Optional[Union[str, dict[EdgeType, str]]] = None,
Collaborator

should we update the doc string for the new arg?

Collaborator

Also it seems a little specific to put the weight_edge_feat_name here when _data_loading_process is kind of a generic function? Or do you think that's fine for now? I'm not sure how else we'd address this.

Collaborator

More robot findings:

Lines 606–621 — Warning promises uniform fallback for partial heterogeneous weights, but GLT will crash

    weights_by_type = {
        edge_type: graph_partition_data.weights
        for edge_type, graph_partition_data in partitioned_edge_index.items()
        if graph_partition_data.weights is not None
    }
    edge_weights = weights_by_type if weights_by_type else None
    if weights_by_type:
        missing = set(partitioned_edge_index.keys()) - set(weights_by_type.keys())
        if missing:
            logger.warning(
                f"... When with_weight=True, edge types without weights "
                f"will fall back to uniform sampling."
            )

GLT does not fall back to uniform sampling in this case. graphlearn_torch/sampler/neighbor_sampler.py:104-113 unconditionally instantiates pywrap.CPUWeightedSampler for every edge type when with_weight=True. For unweighted types, Topology keeps edge_weights = torch.empty(0), Graph::InitCPUGraphFromCSR leaves edge_weight_ unset, and CPUWeightedSampler::Sample dereferences a null prob pointer in std::discrete_distribution<> — undefined behavior, typically a crash.

Fix: Either tighten BaseDistLoader.validate_with_weight to require weights for every edge type when with_weight=True, or synthesize all-ones weight tensors here for missing edge types:

    if weights_by_type:
        for edge_type, gpd in partitioned_edge_index.items():
            if gpd.weights is None:
                weights_by_type[edge_type] = torch.ones(
                    gpd.edge_index.shape[1], dtype=torch.float32
                )
        edge_weights = weights_by_type

Either way, rewrite the warning to reflect actual behavior. Add a test with one weighted + one unweighted sampled edge type and with_weight=True.
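
If the stricter-validation route is taken instead, it might look like this sketch (the dataset accessors are assumptions):

    def _check_weights_cover_sampled_edge_types(dataset, with_weight: bool) -> None:
        """Hedged sketch: fail fast instead of letting CPUWeightedSampler crash."""
        if not with_weight:
            return
        weighted = set(dataset.edge_weights.keys())  # hypothetical accessor
        missing = set(dataset.edge_types) - weighted
        if missing:
            raise ValueError(
                f"with_weight=True, but edge types {sorted(missing)} have no "
                "registered weights; GLT would instantiate CPUWeightedSampler "
                "with an empty weight array for them."
            )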

Comment thread gigl/common/data/load_torch_tensors.py Outdated
f"weight_edge_feat_name '{col_name}' not found in edge feature keys "
f"for edge type {edge_type}: {feature_keys}"
)
col_idx = feature_keys.index(col_name)
Collaborator

Robot finding - is this true? Do we support multi dim features?

Lines 185, 207 — feature_keys.index(col_name) assumes each feature key is exactly one column wide

col_idx = feature_keys.index(col_name)
weights[edge_type] = feat_tensor[:, col_idx]
keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_idx]
features[edge_type] = feat_tensor[:, keep_cols]

_concatenate_features_by_names (gigl/common/data/dataloaders.py:155-187) supports multi-dim features, so a single feature key may contribute multiple columns. Example: feature_keys = ["embedding", "weight"] where embedding has shape [N, 16] and weight has shape [N, 1] → the concatenated tensor has 17 columns and the weight lives at column 16, not column 1. tests/unit/common/data/dataloaders_test.py:494 only covers scalar features, so it doesn't catch this.

Fix: Compute the actual column offset by summing the widths of preceding features (from feat_spec[feature_key].shape or by deriving from the concatenated tensor's per-feature widths), and assert the weight feature has width 1 before squeezing:

col_widths = [serialized_tf_record_info[edge_type].feature_spec[k].shape[-1] or 1 for k in feature_keys]
col_offset = sum(col_widths[: feature_keys.index(col_name)])
weight_width = col_widths[feature_keys.index(col_name)]
assert weight_width == 1, f"weight column '{col_name}' must be width 1, got {weight_width}"
weights[edge_type] = feat_tensor[:, col_offset]
keep_cols = [i for i in range(feat_tensor.shape[1]) if i != col_offset]
features[edge_type] = feat_tensor[:, keep_cols]

Add a test with at least one multi-dim feature alongside the weight column.
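
A self-contained numeric check of the offset arithmetic in that fix (pure torch, no GiGL code):

    import torch

    # feature_keys = ["embedding", "weight"]; embedding is [N, 16], weight is [N, 1].
    num_edges, emb_dim = 4, 16
    embedding = torch.randn(num_edges, emb_dim)
    weight = torch.arange(num_edges, dtype=torch.float32).unsqueeze(1)
    concat = torch.cat([embedding, weight], dim=1)  # shape [4, 17]

    feature_keys = ["embedding", "weight"]
    col_widths = [emb_dim, 1]
    col_offset = sum(col_widths[: feature_keys.index("weight")])  # 16
    # feature_keys.index("weight") == 1 would point inside the embedding block.
    assert torch.equal(concat[:, col_offset], weight.squeeze(1))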

Comment thread gigl/common/data/load_torch_tensors.py Outdated
@mkolodner-sc
Collaborator Author

/unit_test

@github-actions
Contributor

github-actions Bot commented May 14, 2026

GiGL Automation

@ 23:03:34UTC : 🔄 Python Unit Test started.

@ 24:11:45UTC : ❌ Workflow failed.
Please check the logs for more details.

@github-actions
Contributor

github-actions Bot commented May 14, 2026

GiGL Automation

@ 23:03:35UTC : 🔄 C++ Unit Test started.

@ 23:05:34UTC : ✅ Workflow completed successfully.

@github-actions
Contributor

github-actions Bot commented May 14, 2026

GiGL Automation

@ 23:03:35UTC : 🔄 Scala Unit Test started.

@ 23:13:48UTC : ✅ Workflow completed successfully.

@mkolodner-sc
Collaborator Author

/unit_test

@github-actions
Contributor

github-actions Bot commented May 15, 2026

GiGL Automation

@ 06:09:08UTC : 🔄 Scala Unit Test started.

@ 06:19:33UTC : ✅ Workflow completed successfully.

@github-actions
Contributor

github-actions Bot commented May 15, 2026

GiGL Automation

@ 06:09:09UTC : 🔄 C++ Unit Test started.

@ 06:13:19UTC : ✅ Workflow completed successfully.

@github-actions
Contributor

github-actions Bot commented May 15, 2026

GiGL Automation

@ 06:09:09UTC : 🔄 Python Unit Test started.

@ 07:13:24UTC : ✅ Workflow completed successfully.
